{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\gym\\envs\\registration.py:555: UserWarning: \u001b[33mWARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.\u001b[0m\n",
      "  logger.warn(\n",
      "c:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\gym\\envs\\registration.py:627: UserWarning: \u001b[33mWARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']\u001b[0m\n",
      "  logger.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['NOOP'], ['right'], ['right', 'A'], ['right', 'B'], ['right', 'A', 'B'], ['A'], ['left'], ['left', 'A'], ['left', 'B'], ['left', 'A', 'B'], ['down'], ['up']]\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "import time\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from collections import deque\n",
    "from gym.wrappers import GrayScaleObservation\n",
    "from nes_py.wrappers import JoypadSpace\n",
    "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT\n",
    "from gym_super_mario_bros.actions import COMPLEX_MOVEMENT\n",
    "\n",
    "# 1. 创建马里奥环境\n",
    "env = gym.make(\"SuperMarioBros-1-1-v0\", apply_api_compatibility=True, render_mode=\"human\")\n",
    "env = JoypadSpace(env, SIMPLE_MOVEMENT)  # 约束动作空间\n",
    "env = GrayScaleObservation(env, keep_dim=True)  # 灰度化\n",
    "print(COMPLEX_MOVEMENT)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. 定义 DQN 神经网络\n",
    "class MarioDQN(nn.Module):\n",
    "    def __init__(self, input_shape, num_actions):\n",
    "        super(MarioDQN, self).__init__()\n",
    "        self.conv = nn.Sequential(\n",
    "            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 64, kernel_size=4, stride=2),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 64, kernel_size=3, stride=1),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(64 * 7 * 7, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, num_actions)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        return self.fc(x)\n",
    "\n",
    "# 3. 初始化 DQN 模型\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = MarioDQN((1, 84, 84), len(SIMPLE_MOVEMENT)).to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "loss_fn = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4. 经验回放（Replay Buffer）\n",
    "class ReplayBuffer:\n",
    "    def __init__(self, capacity):\n",
    "        self.buffer = deque(maxlen=capacity)\n",
    "\n",
    "    def push(self, state, action, reward, next_state, done):\n",
    "        self.buffer.append((state, action, reward, next_state, done))\n",
    "\n",
    "    def sample(self, batch_size):\n",
    "        return random.sample(self.buffer, batch_size)\n",
    "\n",
    "buffer = ReplayBuffer(10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\gym\\utils\\passive_env_checker.py:272: UserWarning: \u001b[33mWARN: No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`.\u001b[0m\n",
      "  logger.warn(\n",
      "c:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\gym\\utils\\passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
      "  if not isinstance(terminated, (bool, np.bool8)):\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode 0, Total Reward: 277.0\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Given groups=1, weight of size [32, 1, 8, 8], expected input[1, 240, 256, 1] to have 1 channels, but got 240 channels instead",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[5], line 21\u001b[0m\n\u001b[0;32m     19\u001b[0m     state_tensor \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(state, device\u001b[38;5;241m=\u001b[39mdevice)\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m     20\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m---> 21\u001b[0m         action \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate_tensor\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39margmax()\u001b[38;5;241m.\u001b[39mitem()\n\u001b[0;32m     23\u001b[0m \u001b[38;5;66;03m# 执行动作\u001b[39;00m\n\u001b[0;32m     24\u001b[0m next_state, reward, done, truncated, info \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mstep(action)\n",
      "File \u001b[1;32mc:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mc:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[1;32mc:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\container.py:215\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    213\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m    214\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 215\u001b[0m         \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m    216\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
      "File \u001b[1;32mc:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mc:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[1;32mc:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\conv.py:460\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    459\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m--> 460\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mc:\\Users\\16673\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\conv.py:456\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[1;34m(self, input, weight, bias)\u001b[0m\n\u001b[0;32m    452\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m    453\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv2d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[0;32m    454\u001b[0m                     weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[0;32m    455\u001b[0m                     _pair(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[1;32m--> 456\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    457\u001b[0m \u001b[43m                \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: Given groups=1, weight of size [32, 1, 8, 8], expected input[1, 240, 256, 1] to have 1 channels, but got 240 channels instead"
     ]
    }
   ],
   "source": [
    "# 5. 训练循环\n",
    "epsilon = 1.0  # 探索率\n",
    "gamma = 0.99  # 折扣因子\n",
    "\n",
    "for episode in range(1000):\n",
    "    state, _ = env.reset()\n",
    "    state = np.array(state, dtype=np.float32) / 255.0  # 归一化\n",
    "    state = np.transpose(state, (2, 0, 1))  # Move channel dimension to the front\n",
    "    # Ensure the state has the correct number of channels\n",
    "    if state.shape[0] != 1:\n",
    "        state = np.mean(state, axis=0, keepdims=True)  # Convert to grayscale if needed\n",
    "    done = False\n",
    "    total_reward = 0\n",
    "\n",
    "    while not done:\n",
    "        env.render()\n",
    "        time.sleep(0.01)  # 可以根据需要加速或去除\n",
    "\n",
    "        # ε-greedy 选择动作\n",
    "        if random.random() < epsilon:\n",
    "            action = env.action_space.sample()  # 随机探索\n",
    "        else:\n",
    "            state_tensor = torch.tensor(state, device=device).unsqueeze(0)\n",
    "            with torch.no_grad():\n",
    "                action = model(state_tensor).argmax().item()\n",
    "\n",
    "        # 执行动作\n",
    "        next_state, reward, done, truncated, info = env.step(action)\n",
    "        next_state = np.array(next_state, dtype=np.float32) / 255.0  # 归一化\n",
    "\n",
    "        if action == 1:\n",
    "            total_reward += 1\n",
    "\n",
    "        buffer.push(state, action, reward, next_state, done)\n",
    "        state = next_state\n",
    "        total_reward += reward\n",
    "\n",
    "    # 训练 DQN\n",
    "    if len(buffer.buffer) > 1000:\n",
    "        batch = buffer.sample(32)\n",
    "        state_batch = torch.tensor([b[0] for b in batch], dtype=torch.float32, device=device)\n",
    "        action_batch = torch.tensor([b[1] for b in batch], device=device)\n",
    "        reward_batch = torch.tensor([b[2] for b in batch], dtype=torch.float32, device=device)\n",
    "        next_state_batch = torch.tensor([b[3] for b in batch], dtype=torch.float32, device=device)\n",
    "        done_batch = torch.tensor([b[4] for b in batch], dtype=torch.float32, device=device)\n",
    "\n",
    "        q_values = model(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze()\n",
    "        next_q_values = model(next_state_batch).max(1)[0]\n",
    "        expected_q_values = reward_batch + (1 - done_batch) * gamma * next_q_values\n",
    "\n",
    "        loss = loss_fn(q_values, expected_q_values.detach())\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # 逐步降低探索率\n",
    "    epsilon = max(0.1, epsilon * 0.995)\n",
    "\n",
    "    print(f\"Episode {episode}, Total Reward: {total_reward}\")\n",
    "\n",
    "    # 每 50 局保存模型\n",
    "    if episode % 50 == 0:\n",
    "        torch.save(model.state_dict(), \"mario_dqn.pth\")\n",
    "\n",
    "env.close()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
